Conversation

@suexu1025 commented Oct 17, 2025

Description

Add a config flag, float32_weight_sum, to control whether full fp32 precision is used for the weight_sum during the final unpermute in MoE.
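As a rough sketch of what this controls (function and tensor names, shapes, and the einsum spec below are illustrative assumptions, not the actual moe.py code), the flag decides whether the expert-weighted combination is accumulated in float32 before being cast back to the model dtype:

import jax.numpy as jnp

def weighted_combine(expert_outputs, routing_weights, float32_weight_sum, model_dtype=jnp.bfloat16):
  # expert_outputs:  [tokens, experts_per_token, hidden]
  # routing_weights: [tokens, experts_per_token]
  if float32_weight_sum:
    # Up-cast so the multiply-accumulate of the weight_sum runs in full fp32.
    expert_outputs = expert_outputs.astype(jnp.float32)
    routing_weights = routing_weights.astype(jnp.float32)
  combined = jnp.einsum("tkh,tk->th", expert_outputs, routing_weights)
  # Cast back so downstream layers see the usual activation dtype.
  return combined.astype(model_dtype)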

Tests

Final eval loss at 300 steps:

  • 2.394 (cloudlog: https://cloudlogging.app.goo.gl/Q5o2tac9aypGGMyV6)
  • 2.393 (cloudlog: https://cloudlogging.app.goo.gl/L2N43dAZiHap1Djk7)
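The two losses are nearly identical, suggesting the extra precision has only a marginal effect on this workload. For intuition, here is a toy, made-up example (numbers are not from this PR) of the kind of rounding difference in the weighted sum that the flag targets:

import jax.numpy as jnp

w = jnp.array([0.4, 0.3, 0.2, 0.1])                  # routing weights for one token
x = jnp.array([[1.001], [0.999], [1.003], [0.997]])  # per-expert outputs, hidden size 1

fp32_sum = jnp.einsum("k,kh->h", w, x)  # full-precision weighted sum
bf16_sum = jnp.einsum("k,kh->h", w.astype(jnp.bfloat16),
                      x.astype(jnp.bfloat16)).astype(jnp.float32)  # inputs rounded to bf16 first
print(fp32_sum, bf16_sum)  # bf16 keeps ~8 mantissa bits, so the two results can differ slightly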

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@suexu1025 requested a review from RissyRan October 17, 2025 18:54
@suexu1025 requested a review from RissyRan October 31, 2025 19:01
@github-actions

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

github-actions bot left a comment

📋 Review Summary

This pull request introduces a new configuration flag, float32_weight_sum, to control the precision of the weighted sum operation in the Mixture of Experts (MoE) layers. The changes are well-implemented and provide useful flexibility for balancing performance and numerical precision.

🔍 General Feedback

  • The addition of the float32_weight_sum flag is a good feature for optimizing MoE layers.
  • The implementation in src/MaxText/layers/moe.py correctly applies the conditional casting based on the new configuration.
  • A minor style suggestion was made to improve comment consistency.

cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax
float32_weight_sum: True # whether to use full fp32 precision for weight_sum during final unpermute in moe

🟢 Nit: For consistency and clarity, it's better to use "MoE" instead of "moe".

Suggested change
float32_weight_sum: True # whether to use full fp32 precision for weight_sum during final unpermute in moe
float32_weight_sum: True # whether to use full fp32 precision for weight_sum during final unpermute in MoE

Commits:

  • update
  • update
  • Update base.yml
  • Update moe.py
  • Update moe.py
@suexu1025 force-pushed the qinwen/add_up_quantize_config branch from 6f0349d to ee4e8cc on October 31, 2025 23:07